{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from loguru import logger\n",
    "from sklearn import datasets\n",
    "from sklearn.ensemble import RandomForestRegressor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the Dataset\n",
    "\n",
    "Define our response `y` and predictor, `X`, variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iris = datasets.load_iris()\n",
    "dataset = pd.DataFrame(data=iris[\"data\"], columns=iris[\"feature_names\"])\n",
    "X = dataset.iloc[:, 1:]\n",
    "y = dataset.iloc[:, 0]\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit a Random Forest Regressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = RandomForestRegressor(n_estimators=20, bootstrap=True)\n",
    "model.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create our ALE Plots\n",
    "\n",
    "### 1D Main Effect ALE Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "\n",
    "from alepython import ale_plot\n",
    "\n",
    "mpl.rc(\"figure\", figsize=(9, 6))\n",
    "ale_plot(\n",
    "    model,\n",
    "    X,\n",
    "    X.columns[:1],\n",
    "    bins=20,\n",
    "    monte_carlo=True,\n",
    "    monte_carlo_rep=100,\n",
    "    monte_carlo_ratio=0.6,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2D Second-Order ALE Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rc(\"figure\", figsize=(9, 6))\n",
    "ale_plot(model, X, X.columns[:2], bins=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rc(\"figure\", figsize=(9, 6))\n",
    "ale_plot(model, X, X.columns[1:], bins=10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
